import os
import time
import torch
import numpy as np
import rerun as rr
from lietorch import SE3
import droid_backends
import imageio.v2 as imageio
import shutil
import subprocess

# --- headless safety (no Qt/GUI on cluster) ---
# Respect a QT_QPA_PLATFORM that is already set; otherwise pick "offscreen".
if "QT_QPA_PLATFORM" not in os.environ:
    os.environ["QT_QPA_PLATFORM"] = "offscreen"
# Drop DISPLAY from this process's environment when present.
if "DISPLAY" in os.environ:
    del os.environ["DISPLAY"]

# ---- helpers to draw a camera frustum as 3D line segments ----
# Vertices of the camera glyph in the camera frame: index 0 is the origin,
# indices 1-4 are the corners of a rectangle at z = 1.5, and indices 5-7
# form a small marker along the y = +1 edge of that rectangle.
_CAM_POINTS = np.asarray(
    (
        (0.0, 0.0, 0.0),
        (-1.0, -1.0, 1.5),
        (1.0, -1.0, 1.5),
        (1.0, 1.0, 1.5),
        (-1.0, 1.0, 1.5),
        (-0.5, 1.0, 1.5),
        (0.5, 1.0, 1.5),
        (0.0, 1.2, 1.5),
    ),
    dtype=np.float32,
)

# Each row is a pair of indices into _CAM_POINTS describing one segment.
_CAM_LINES = np.asarray(
    (
        (1, 2), (2, 3), (3, 4), (4, 1),   # rectangle outline
        (1, 0), (0, 2), (3, 0), (0, 4),   # origin to each corner
        (5, 7), (7, 6),                   # marker on the y = +1 edge
    ),
    dtype=np.int32,
)


def _camera_lines(scale=0.05):
    """Return the camera glyph as a list of (2, 3) float32 segments.

    Args:
        scale: uniform factor applied to every vertex of `_CAM_POINTS`.

    Returns:
        A list with one array per row of `_CAM_LINES`; each array holds the
        scaled start and end points of that segment.
    """
    scaled = (_CAM_POINTS * scale).astype(np.float32)
    # Fancy indexing gathers both endpoints of every segment at once: (S, 2, 3).
    endpoints = scaled[_CAM_LINES]
    return list(endpoints)

def _to_rr_transform(mat4):
    """Wrap a 4x4 matrix in an `rr.Transform3D`.

    The matrix is cast to float32 and passed first as the `mat4x4` keyword;
    if that call raises TypeError it is passed positionally instead. Any
    exception from the second call propagates to the caller.

    NOTE(review): confirm which (if either) of these call forms the installed
    rerun-sdk actually accepts for a full 4x4 matrix.
    """
    matrix = mat4.astype(np.float32)
    try:
        transform = rr.Transform3D(mat4x4=matrix)
    except TypeError:
        transform = rr.Transform3D(matrix)
    return transform

def _rr_init(app_id: str):
    """Call `rr.init`, trying several call signatures in turn.

    The attempts are, in order: positional id, `application_id=` keyword,
    `app_id=` keyword. An attempt that raises TypeError is skipped; the first
    one that does not raise wins. If all three raise TypeError, `rr.init()` is
    called with no arguments.
    """
    call_forms = (
        lambda: rr.init(app_id),
        lambda: rr.init(application_id=app_id),
        lambda: rr.init(app_id=app_id),
    )
    for call in call_forms:
        try:
            call()
        except TypeError:
            continue
        return
    # Last resort: initialise without an application id.
    rr.init()

def _rr_warn_once(seen, what, exc):
    """Print one warning per failure category instead of one per keyframe."""
    if what in seen:
        return
    seen.add(what)
    print(f"[Rerun] {what} logging failed (further failures suppressed): {exc!r}")


def _frustum_segments_world(world_from_cam, segs_cam):
    """Map (S, 2, 3) segments through the top three rows of a 4x4 transform."""
    rot = world_from_cam[:3, :3]
    trans = world_from_cam[:3, 3]
    # Row-vector form of R @ p + t, applied to every endpoint at once.
    return segs_cam @ rot.T + trans


def droid_visualization_rerun(
    video,
    device="cuda:0",
    app_id="droid_viz",
    web_port=9876,                # web server port on the node
    record_path=None,             # e.g., "output_fastlivo/debug_rerun/stream.rrd"
    warmup=8,
    filter_thresh=0.5, #filter_thresh=0.005,
    scale=1.0,
):
    """
    Start a headless Rerun web viewer on this node and stream keyframes to it.

    The function polls `video.dirty` forever; for every dirty keyframe it logs
    the camera pose, a frustum glyph and a filtered, coloured point cloud.
    It returns (with None) only when interrupted by KeyboardInterrupt.

    To view remotely, forward the port over SSH and open it in a browser:
        ssh -N -L <web_port>:localhost:<web_port> <user>@<compute-node>
        http://localhost:<web_port>

    Args:
        video: shared buffer object; this function uses `get_lock()`, `dirty`,
            `poses`, `images`, `intrinsics` and `disps` (plus `disps_up` when
            the local `high_res` switch is enabled).
        device: CUDA device handed to `torch.cuda.set_device`.
        app_id: application id forwarded to `rr.init`.
        web_port: port passed to `rr.serve_web`.
        record_path: if truthy, the stream is also written there via `rr.save`.
        warmup: currently unused; kept for interface compatibility.
        filter_thresh: threshold passed (one copy per dirty frame) to
            `droid_backends.depth_filter`.
        scale: currently unused; kept for interface compatibility.
    """
    torch.cuda.set_device(device)
    _rr_init(app_id)

    # Start a local web server on the node (no GUI/X11 required).
    try:
        rr.serve_web(port=web_port, open_browser=False)
        print(f"[Rerun] Web server on http://127.0.0.1:{web_port}")
    except (TypeError, AttributeError):
        # Signature mismatch or missing function: fall back to serve() without
        # args. AttributeError is included so a rerun build without
        # `serve_web` degrades to the fallback instead of crashing.
        try:
            rr.serve()
            print("[Rerun] Started legacy rr.serve() (port is chosen by Rerun; check logs).")
        except Exception as e:
            print(f"[Rerun] Could not start web server: {e}")

    # Optional: record to file for later playback.
    if record_path:
        try:
            rr.save(record_path)
            print(f"[Rerun] Recording to {record_path}")
        except Exception as e:
            print(f"[Rerun] Recording not started: {e}")

    # Global scene axes: RDF (right, down, forward). Best-effort.
    try:
        rr.log("world", rr.ViewCoordinates.RDF, timeless=True)
    except Exception as e:
        print(f"[Rerun] Could not set view coordinates: {e}")

    filter_thresh = float(filter_thresh)
    # Developer switch: use the upsampled disparities (`video.disps_up`) and
    # full-resolution images instead of the 1/8-resolution ones.
    high_res = False
    # Camera glyph in the camera frame, stacked once: (S, 2, 3).
    frustum_cam = np.stack(_camera_lines(scale=0.1), axis=0)
    warned = set()

    try:
        while True:
            # Check which frames need updating.
            with video.get_lock():
                dirty_index, = torch.where(video.dirty.clone())

            if dirty_index.numel() == 0:
                time.sleep(0.01)
                continue

            # Mark processed.
            video.dirty[dirty_index] = False

            # Fetch tensors for the dirty keyframes.
            poses = torch.index_select(video.poses, 0, dirty_index)
            # Inverse of the stored poses; presumably world_T_cam — TODO confirm
            # the convention used by `video.poses`.
            world_from_cams = SE3(poses).inv()
            Ps = world_from_cams.matrix().cpu().numpy()  # [N, 4, 4]
            images = torch.index_select(video.images, 0, dirty_index.cpu())
            intrinsics = video.intrinsics[0].clone()
            if high_res:
                disps_all = video.disps_up
                disps = torch.index_select(disps_all, 0, dirty_index.cpu())
                images = images.cpu().permute(0, 2, 3, 1) / 255.0  # [N, H, W, 3]
                intrinsics *= 8
            else:
                disps_all = video.disps
                disps = torch.index_select(disps_all, 0, dirty_index)
                # Sample every 8th pixel (offset 3) of the stored images.
                images = images.cpu()[..., 3::8, 3::8].permute(0, 2, 3, 1) / 255.0  # [N, H, W, 3]

            # Consistency filter: one threshold value per dirty frame.
            thresh = filter_thresh * torch.ones_like(disps.mean(dim=[1, 2]))
            count = droid_backends.depth_filter(video.poses, disps_all, intrinsics, dirty_index, thresh)

            # Back-project before moving `disps` to the CPU so it is not
            # copied to the host and straight back to the GPU.
            points = droid_backends.iproj(world_from_cams.data, disps.cuda(), intrinsics).cpu()  # [N,H,W,3]

            count = count.cpu()
            disps = disps.cpu()
            # Keep pixels with a filter count of at least 2 whose disparity
            # exceeds half of that frame's mean disparity.
            masks = (count >= 2) & (disps > .5 * disps.mean(dim=[1, 2], keepdim=True))  # [N,H,W]

            # Log per keyframe.
            for i, kf_idx in enumerate(dirty_index.tolist()):
                world_from_cam = Ps[i]
                cam_path = f"world/cameras/{kf_idx:06d}"

                # Pose (best-effort; report the first failure only).
                try:
                    rr.log(cam_path, _to_rr_transform(world_from_cam))
                except Exception as e:
                    _rr_warn_once(warned, "camera pose", e)

                # Frustum lines, expressed in world coordinates.
                # NOTE(review): these vertices are already transformed, yet they
                # are logged beneath `cam_path`; if the pose transform above is
                # logged successfully the viewer would presumably apply it a
                # second time — confirm against the rerun transform hierarchy.
                try:
                    segs_world = _frustum_segments_world(world_from_cam, frustum_cam)
                    rr.log(f"{cam_path}/frustum", rr.LineStrips3D(segs_world))  # (S,2,3)
                except Exception as e:
                    _rr_warn_once(warned, "camera frustum", e)

                # Colored point cloud.
                keep = masks[i].reshape(-1)
                pts = points[i].reshape(-1, 3)[keep].numpy()
                clr = images[i].reshape(-1, 3)[keep].numpy()
                # High-res clouds are very dense; keep every 512th point so
                # the viewer stays responsive.
                if high_res:
                    pts = pts[::512, :]
                    clr = clr[::512, :]
                if pts.size > 0:
                    points_path = f"world/points/{kf_idx:06d}"
                    try:
                        rr.log(points_path, rr.Points3D(pts, colors=(clr * 255).astype(np.uint8)))
                    except Exception as e:
                        # Retry without colours rather than dropping the points.
                        _rr_warn_once(warned, "point colour", e)
                        rr.log(points_path, rr.Points3D(pts))

    except KeyboardInterrupt:
        pass
